import torch
import torchvision.datasets
import torchvision.transforms as transforms

def _balanced_dataset(dataset, data_amount_per_class, num_classes):
    indices = torch.tensor([])
    for c in range(num_classes):
        equals = (dataset.targets == c)
        indices_temp = torch.nonzero(equals)[:,0]
        indices = torch.cat((indices,indices_temp[:data_amount_per_class]),0)
    return torch.utils.data.Subset(dataset, indices[torch.randperm(indices.shape[0])].long())


def load_balanced_dataset(dataset, batch_size, data_amount_per_class, num_classes, shuffle=False, **kwargs):
    """Build a ``DataLoader`` over a class-balanced subset of ``dataset``.

    The subset comes from ``_balanced_dataset(dataset, data_amount_per_class,
    num_classes)``. ``batch_size``, ``shuffle`` and any extra keyword
    arguments are forwarded unchanged to ``torch.utils.data.DataLoader``.
    """
    balanced_subset = _balanced_dataset(dataset, data_amount_per_class, num_classes)
    loader = torch.utils.data.DataLoader(
        balanced_subset,
        batch_size=batch_size,
        shuffle=shuffle,
        **kwargs,
    )
    return loader

def load_dataset(root, **load_kwargs):
    """Load a ``(trainset, testset)`` pair from a file written by ``torch.save``.

    Args:
        root: path (or file-like object) accepted by ``torch.load``.
        **load_kwargs: extra keyword arguments forwarded to ``torch.load``
            (e.g. ``map_location``). They may also override ``weights_only``.

    Returns:
        Tuple ``(trainset, testset)`` read from the ``"trainset"`` and
        ``"testset"`` entries of the loaded object.

    Raises:
        KeyError: if the loaded mapping lacks either entry.
    """
    import inspect

    # Newer torch versions default ``torch.load`` to ``weights_only=True``,
    # which refuses to unpickle Dataset objects. Request full unpickling
    # explicitly when the installed torch supports the flag.
    # SECURITY: this executes arbitrary pickle code -- only load trusted files.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs.setdefault("weights_only", False)
    dataset = torch.load(root, **load_kwargs)
    trainset = dataset["trainset"]
    testset = dataset["testset"]
    return trainset, testset